-- -- -- -- 卷积的矩阵乘法实现：
-- -- -- --   https://www.cnblogs.com/shine-lee/p/10775831.html
-- -- -- --   https://blog.csdn.net/qq_40263477/article/details/104979609
-- -- 卷积的 fft 实现
-- --   https://blog.csdn.net/qq_37527608/article/details/120922348
-- -- 卷积与矩阵乘法算法调优：
-- --      -- --  https://blog.csdn.net/lizhengx/article/details/83246833
-- --      -- -- 卷积求导
-- --      -- --  https://blog.csdn.net/wangjianyu0115/article/details/62222301
-- --      
-- --      -- set search_path to sm_sc;
-- drop function if exists sm_sc.fv_d_conv_2d_dloss_dindepdt_2(float[], float[], int[], int[2], int[4], float);
-- Back-propagation helper for a 2-d convolution: given the first operand of the
-- convolution (i_array) and dloss/dy (i_dloss_ddepdt), accumulates, over every
-- window position, the window slice of the padded input scaled by the matching
-- dloss/dy element.  The function name indicates this is the gradient of the loss
-- with respect to the second operand (the kernel window).
--
-- Supported inputs: i_array and i_dloss_ddepdt are 2-d, 3-d or 4-d; the last two
-- dimensions are always height and width.  3-d / 4-d inputs are processed by
-- recursing on each 2-d slice of the leading dimension(s).
--
-- Validation of shapes only runs when the setting pg4ml._v_is_debug_check = '1'.
-- Without validation, trailing rows/columns that cannot host a complete window
-- are skipped rather than read past the edge of the padded matrix.
--
-- Raises:
--   'unsupport ndims of i_array and i_dloss_ddepdt.'  when the number of dimensions is not 2, 3 or 4
--   the other audit messages below                     only when the debug check is enabled
create or replace function sm_sc.fv_d_conv_2d_dloss_dindepdt_2
(
  i_array              float[]                                     ,  -- original matrix (first operand of the convolution)
  i_dloss_ddepdt       float[]                                     ,  -- already-computed derivative of the loss with respect to y
  i_window_len         int[]                                               ,  -- kernel window shape; for a 3-d / 4-d kernel this has 3 / 4 entries, the last two being height and width
  i_stride             int[2]              default  array[1, 1]             ,  -- vertical and horizontal stride
  i_padding            int[4]              default  array[0, 0, 0, 0]       ,  -- number of rows/columns padded at top, bottom, left, right
  i_padding_value      float      default  0.0                        -- value used to fill the padded cells
)
returns float[]
as
$$
declare 
  v_ndims                  int   :=   array_ndims(i_dloss_ddepdt);
  -- height / width are always the last two dimensions of the inputs
  v_array_len_heigh        int   :=   array_length(i_array, array_ndims(i_array) - 1) ;
  v_array_len_width        int   :=   array_length(i_array, array_ndims(i_array))     ;
  v_dloss_ddepdt_len_heigh int   :=   array_length(i_dloss_ddepdt, array_ndims(i_dloss_ddepdt) - 1) ;
  v_dloss_ddepdt_len_width int   :=   array_length(i_dloss_ddepdt, array_ndims(i_dloss_ddepdt)) ;
  -- padding amounts, with missing / NULL entries treated as 0 everywhere they are used
  v_pad_top                int   :=   coalesce(i_padding[1], 0);
  v_pad_bottom             int   :=   coalesce(i_padding[2], 0);
  v_pad_left               int   :=   coalesce(i_padding[3], 0);
  v_pad_right              int   :=   coalesce(i_padding[4], 0);
  v_array_len_heigh_ex     int   :=   v_pad_top  + v_array_len_heigh + v_pad_bottom;    -- height of the padded matrix
  v_array_len_width_ex     int   :=   v_pad_left + v_array_len_width + v_pad_right;     -- width of the padded matrix
  -- only the last two entries of i_window_len (window height, window width) are used for sliding
  v_window_len             int[2] :=   i_window_len[array_length(i_window_len, 1) - 1 : array_length(i_window_len, 1)];
  v_array_ex               float[];   -- padded copy of i_array (2-d branch only)
  v_ret                    float[];
begin
  -- audit (only when the debug-check setting is switched on)
  if current_setting('pg4ml._v_is_debug_check', true) = '1'
  then
    -- audit number of dimensions
    if array_ndims(i_array) not in (2, 3, 4) or v_ndims not in (2, 3, 4) or v_ndims <> array_ndims(i_array)
    then 
      raise exception 'unsupport ndims of i_array and i_dloss_ddepdt.';
    -- leading (non height/width) dimensions of both inputs must agree
    elsif v_ndims = 3 and array_length(i_dloss_ddepdt, 1) <> array_length(i_array, 1)
      or (v_ndims = 4 and (array_length(i_dloss_ddepdt, 1) <> array_length(i_array, 1) or array_length(i_dloss_ddepdt, 2) <> array_length(i_array, 2)))
    then 
      raise exception 'unmatch length at 1d of 3d or 1d / 2d of 4d';
    -- the stride must tile the padded matrix exactly
    elsif (v_array_len_heigh_ex - v_window_len[1]) % i_stride[1] <> 0
    then 
      raise exception 'imperfect window at heigh.';
    elsif (v_array_len_width_ex - v_window_len[2]) % i_stride[2] <> 0
    then 
      raise exception 'imperfect window at width.';
    -- dloss/dy must have one element per window position
    elsif v_dloss_ddepdt_len_heigh <> (v_array_len_heigh_ex - v_window_len[1]) / i_stride[1] + 1
      or v_dloss_ddepdt_len_width <> (v_array_len_width_ex - v_window_len[2]) / i_stride[2] + 1
    then
      raise exception 'unmatch length between y and dloss/dy.';
    end if;
  end if;
  
  if v_ndims = 2
  then 
    -- Pad the input.  NOTE(review): sm_sc.fv_augmented is assumed to take the new
    -- lower / upper index bounds (relative to the original 1-based indexes) and a
    -- fill value, returning a 1-based matrix -- confirm against its definition.
    v_array_ex := 
      sm_sc.fv_augmented
      (
        i_array, 
        array[-v_pad_top + 1, -v_pad_left + 1], 
        array[v_array_len_heigh + v_pad_bottom, v_array_len_width + v_pad_right], 
        i_padding_value
      );
    -- For every window start (col_a_y, col_a_x): window slice *` matching dloss/dy
    -- element, aggregated by sm_sc.fa_mx_sum.  `|~~|` and `*`` are project-defined
    -- operators whose definitions are not visible in this file.
    -- The last valid window start is (padded length - window length + 1); stopping
    -- there keeps every slice inside the padded matrix even for imperfect windows.
    return 
      |~~|
      (
        select 
          sm_sc.fa_mx_sum
          (
            v_array_ex[col_a_y : col_a_y + v_window_len[1] - 1][col_a_x : col_a_x + v_window_len[2] - 1] 
              *` i_dloss_ddepdt[(col_a_y - 1) / i_stride[1] + 1][(col_a_x - 1) / i_stride[2] + 1]
          )
        from generate_series(1, v_array_len_heigh_ex - v_window_len[1] + 1, i_stride[1]) tb_a_y(col_a_y)
        cross join generate_series(1, v_array_len_width_ex - v_window_len[2] + 1, i_stride[2]) tb_a_x(col_a_x)
      );
  
  elsif v_ndims = 3
  then 
    -- one recursive 2-d call per slice of the first dimension, stacked back into 3-d
    v_ret := 
    (
      select 
        array_agg -- sm_sc.fa_mx_sum 
        (
          sm_sc.fv_d_conv_2d_dloss_dindepdt_2
          (
            sm_sc.fv_mx_slice_3d_2_2d
            (
              i_array[a_cur_y : a_cur_y]
            , 1
            )
          , sm_sc.fv_mx_slice_3d_2_2d
            (
              i_dloss_ddepdt[a_cur_y : a_cur_y]
            , 1
            )
          , v_window_len     
          , i_stride           
          , i_padding       
          , i_padding_value
          )
          order by a_cur_y
        )
      from generate_series(1, array_length(i_dloss_ddepdt, 1)) tb_a_cur_y(a_cur_y)
    );
    -- A 2-entry i_window_len means a single 2-d kernel shared by all slices, so the
    -- per-slice results are collapsed into one 2-d matrix.
    -- NOTE(review): presumably sm_sc.fv_aggr_slice_sum sums over the first dimension here -- confirm.
    if array_length(i_window_len, 1) = 2
    then 
      v_ret := 
        sm_sc.fv_mx_slice_3d_2_2d
        (
          sm_sc.fv_aggr_slice_sum(v_ret, array[array_length(v_ret, 1), 1, 1])
        , 1
        );
    end if;
    return v_ret;
    
  elsif v_ndims = 4
  then 
    -- one recursive 2-d call per (dim-1, dim-2) slice; results are stacked first
    -- along dim 2 (inside the CTE) and then along dim 1, giving a 4-d array
    v_ret := 
    (
      with 
      cte_agg_x as 
      (
        select 
          a_cur_y, 
          array_agg    -- sm_sc.fa_mx_sum 
          (
            sm_sc.fv_d_conv_2d_dloss_dindepdt_2
            (
              sm_sc.fv_mx_slice_4d_2_2d
              (
                i_array[a_cur_y : a_cur_y][a_cur_x : a_cur_x][ : ][ : ]
              , array[1, 2]
              , array[1, 1]
              )
            , sm_sc.fv_mx_slice_4d_2_2d
              (
                i_dloss_ddepdt[a_cur_y : a_cur_y][a_cur_x : a_cur_x][ : ][ : ]
              , array[1, 2]
              , array[1, 1]
              )
            , v_window_len     
            , i_stride           
            , i_padding       
            , i_padding_value
            )
            order by a_cur_x
          ) as a_agg_x
        from generate_series(1, array_length(i_dloss_ddepdt, 1)) tb_a_cur_y(a_cur_y)
        cross join generate_series(1, array_length(i_dloss_ddepdt, 2)) tb_a_cur_x(a_cur_x)
        group by a_cur_y
      )
      select 
        array_agg(a_agg_x order by a_cur_y)
      from cte_agg_x
    );
    -- A 2-entry i_window_len means a single 2-d kernel shared by all slices, so the
    -- per-slice results are collapsed into one 2-d matrix.
    -- NOTE(review): presumably sm_sc.fv_aggr_slice_sum sums over the first two dimensions here -- confirm.
    if array_length(i_window_len, 1) = 2
    then 
      v_ret := 
        sm_sc.fv_mx_slice_4d_2_2d
        (
          sm_sc.fv_aggr_slice_sum(v_ret, array[array_length(v_ret, 1), array_length(v_ret, 2), 1, 1])
        , array[1, 2]
        , array[1, 1]
        );
    end if;
    return v_ret;

  else
    -- Reached when the debug audit is off and the input is not 2-d / 3-d / 4-d (or is NULL).
    -- Fail with the audit's message instead of falling off the end of the function.
    raise exception 'unsupport ndims of i_array and i_dloss_ddepdt.';
  end if;
end
$$
language plpgsql stable
parallel safe
cost 100;
-- -- set search_path to sm_sc;
-- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
--   (
--     array[[1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--         , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--         , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--         , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--         , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--          ] :: float[]
--    , array[[1.1, 1.1, 1.1], [1.1, 1.1, 1.1]]
--    , array[3, 3]
--    , array[2, 2]
--   );

-- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
--   (
--     array[[1,2,3,4,5,6]
--         , [10,20,30,40,50,60]
--         , [100,200,300,400,500,600]
--         , [-1,-2,-3,-4,-5,-6]
--         , [-10,-20,-30,-40,-50,-60]
--          ] :: float[]
--    , array[[1.1, 1.1, 1.1], [1.1, 2.1, 1.1], [1.1, 1.1, 1.1]]
--    , array[3, 3]
--    , array[2, 2]
--    , array[1, 1, 1, 0]
--    , 0
--   );

-- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
--   (
--     array
--       [
--         [
--           [1,2,3,4,5,6]
--         , [10,20,30,40,50,60]
--         , [100,200,300,400,500,600]
--         , [-1,-2,-3,-4,-5,-6]
--         , [-10,-20,-30,-40,-50,-60]
--         ]
--       , [
--           [-1,2,-3,4,5,6]
--         , [10,-20,30,40,50,-60]
--         , [100,200,-300,400,500,600]
--         , [-1,2,-3,-4,-5,-6]
--         , [-10,-20,30,-40,50,-60]
--         ]
--       ]
--    , array[[[1.1, 1.1, -1.1], [1.1, -2.1, 1.1], [-1.1, 1.1, 1.1]],[[-1.1, 1.1, 1.1], [1.1, -2.1, 1.1], [1.1, 1.1, -1.1]]]
--    , array[3, 3]     --  array[2, 3, 3]
--    , array[2, 2]
--    , array[1, 1, 1, 0]
--    , 0
--   ) :: decimal[] ~=` 3;

-- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
--   (
--    array
--    [
--      [
--        [
--          [1,2,3,4,5,6]
--        , [10,20,30,40,50,60]
--        , [100,200,300,400,500,600]
--        , [-1,-2,-3,-4,-5,-6]
--        , [-10,-20,-30,-40,-50,-60]
--        ]
--      , [
--          [-1,2,-3,4,5,6]
--        , [10,-20,30,40,50,-60]
--        , [100,200,-300,400,500,600]
--        , [-1,2,-3,-4,-5,-6]
--        , [-10,-20,30,-40,50,-60]
--        ]
--      ]
--    , [
--        [
--          [1,2,3,-4,5,6]
--        , [10,20,30,40,50,60]
--        , [100,-200,300,400,500,600]
--        , [-1,2,-3,-4,5,-6]
--        , [10,20,-30,-40,-50,60]
--        ]
--      , [
--          [1,2,-3,-4,5,6]
--        , [10,-20,-30,40,50,60]
--        , [100,200,-300,-400,-500,600]
--        , [-1,2,-3,-4,5,-6]
--        , [-10,-20,30,-40,50,-60]
--        ]
--      ]
--    ]  
--    , array
--      [
--        [[[1.1, 1.1, -1.1], [-1.1, -2.1, -1.1], [-1.1, 1.1, 1.1]]
--        ,[[-1.1, 1.1, 1.1], [1.1, -2.1, 1.1], [1.1, 1.1, -1.1]]]
--      , [[[1.1, 1.1, -1.1], [1.1, -2.1, 1.1], [-1.1, 1.1, 1.1]]
--        ,[[-1.1, 1.1, 1.1], [-1.1, -2.1, -1.1], [1.1, 1.1, -1.1]]]
--      ]
--    , array[3, 3]    -- array[2, 2, 3, 3]
--    , array[2, 2]
--    , array[1, 1, 1, 0]
--    , 0
--   ) :: decimal[] ~=` 3;